"""
Phase 8: Post-QE Prediction and Re-emergence
Tests whether demographic effects on interest rates re-emerge after the QE unwind
(2022-2024), using post-QE regressions, out-of-sample prediction from pre-GFC
coefficients, and rolling-window re-estimation.

Outputs:
  - Table 16: 2022-2024 regressions (phase8_table16_post_qe.md)
  - Table 17: Out-of-sample test (phase8_table17_oos.md)
  - Table 18: Rolling re-emergence (phase8_table18_reemergence.md + phase8_reemergence_data.csv)
  - Table 19: Country-level predicted vs actual (phase8_table19_country.md)
  - Figure: Predicted vs actual scatter (output/figures/phase8_predicted_vs_actual.png)
  - Figure: Natural rate projection (output/figures/phase8_natural_rate_projection.png)
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd

# ── paths ────────────────────────────────────────────────────────────────────
# NOTE(review): absolute, machine-specific project root (looks like a Windows
# drive mounted under /mnt/c) — confirm before running elsewhere.
PROJECT = Path("/mnt/c/demographics_capital_flows")
# Must run before the import below: makes multilateral/src importable so that
# `model` (which provides PanelGLS) can be found.
sys.path.insert(0, str(PROJECT / "multilateral" / "src"))
from model import PanelGLS

# Input panels: the monetary panel drives all estimation; the multilateral
# full panel is only read for the Z_1 natural-rate projection.
DATA_PATH = PROJECT / "monetary" / "data" / "processed" / "monetary_panel.csv"
FULL_PANEL_PATH = PROJECT / "multilateral" / "followup" / "data" / "processed" / "full_panel.csv"
TABLE_DIR = PROJECT / "monetary" / "output" / "tables"
FIG_DIR = PROJECT / "monetary" / "output" / "figures"
# Output directories are created at import time (no-op if they already exist).
TABLE_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)

# ISO3 codes used to restrict every analysis to the OECD sample (38 codes).
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]


# ── helpers ──────────────────────────────────────────────────────────────────
def run_model(panel, dep_var, rhs_vars, min_obs=50):
    """Fit ``PanelGLS`` of ``dep_var`` on ``rhs_vars`` using complete rows of ``panel``.

    Rows missing the dependent variable, any regressor, ``iso3`` or ``year``
    are dropped; the remaining ``iso3`` and ``year`` arrays are handed to
    ``PanelGLS.fit`` alongside ``y`` and ``X``.

    Returns a dict with ``coefs``, ``se`` and ``pvals`` (each keyed by
    regressor name), ``r2``, ``nobs`` and ``ncountries``; or ``None`` when
    fewer than ``min_obs`` complete rows remain or fitting raises (the error
    is printed rather than propagated).
    """
    needed = [dep_var] + rhs_vars + ["iso3", "year"]
    complete = panel[needed].dropna()
    n_complete = len(complete)
    if n_complete < min_obs:
        print(f"  SKIP {dep_var}: only {n_complete} obs after dropna (need {min_obs})")
        return None

    try:
        estimator = PanelGLS()
        estimator.fit(
            complete[dep_var].values,
            complete[rhs_vars].values,
            complete["iso3"].values,
            complete["year"].values,
        )
        summary = {
            "coefs": dict(zip(rhs_vars, estimator.beta)),
            "se": dict(zip(rhs_vars, estimator.se)),
            "pvals": dict(zip(rhs_vars, estimator.pvalues)),
            "r2": estimator.r_squared,
            "nobs": estimator.n_obs,
            "ncountries": estimator.n_countries,
        }
    except Exception as exc:
        # Best-effort: callers treat None as "fall back to OLS".
        print(f"  ERROR {dep_var}: {exc}")
        return None
    return summary


def run_ols(panel, dep_var, rhs_vars):
    """Estimate pooled OLS with an intercept; used as a fallback to PanelGLS.

    Parameters
    ----------
    panel : pandas.DataFrame
        Must contain ``dep_var``, every column in ``rhs_vars``, ``iso3`` and ``year``.
    dep_var : str
        Dependent-variable column.
    rhs_vars : list of str
        Regressor columns; a constant is added automatically.

    Returns
    -------
    dict or None
        ``coefs``/``se``/``pvals`` keyed by regressor name (intercept excluded),
        plus ``intercept``, ``r2``, ``nobs``, ``ncountries`` and ``method="OLS"``.
        ``None`` when fewer than 10 complete rows remain or estimation fails
        (the error is printed). Standard errors and p-values are NaN when the
        design matrix is rank-deficient, since the coefficients are then not
        identified.
    """
    cols = [dep_var] + rhs_vars + ["iso3", "year"]
    df = panel[cols].dropna()
    if len(df) < 10:
        print(f"  SKIP OLS {dep_var}: only {len(df)} obs")
        return None
    try:
        from scipy import stats as sp_stats

        y = df[dep_var].to_numpy(dtype=float)
        X = np.column_stack([np.ones(len(y)), df[rhs_vars].to_numpy(dtype=float)])
        n, k = X.shape

        beta, _, rank, _ = np.linalg.lstsq(X, y, rcond=None)
        resid = y - X @ beta
        dof = max(n - k, 1)
        ss_res = float(resid @ resid)
        s2 = ss_res / dof

        if rank < k:
            # lstsq returns a minimum-norm solution for collinear designs;
            # inv(X'X) may not raise there and would give meaningless SEs.
            se = np.full(k, np.nan)
        else:
            try:
                se = np.sqrt(np.diag(s2 * np.linalg.inv(X.T @ X)))
            except np.linalg.LinAlgError:
                se = np.full(k, np.nan)

        with np.errstate(divide="ignore", invalid="ignore"):
            t_stats = beta / np.where(se > 0, se, np.nan)
        # sf avoids the 1 - cdf cancellation that rounds small p-values to 0.
        pvals = 2 * sp_stats.t.sf(np.abs(t_stats), dof)

        ss_tot = float(np.sum((y - y.mean()) ** 2))
        r2 = 1 - ss_res / ss_tot if ss_tot > 0 else np.nan

        # Index 0 is the constant; report it separately from the regressors.
        return {
            "coefs": dict(zip(rhs_vars, beta[1:])),
            "se": dict(zip(rhs_vars, se[1:])),
            "pvals": dict(zip(rhs_vars, pvals[1:])),
            "intercept": float(beta[0]),
            "r2": r2,
            "nobs": n,
            "ncountries": df["iso3"].nunique(),
            "method": "OLS",
        }
    except Exception as e:
        print(f"  ERROR OLS {dep_var}: {e}")
        return None


def star(p):
    """Return the conventional significance marker for p-value ``p``.

    ``"***"`` below 0.01, ``"**"`` below 0.05, ``"*"`` below 0.10, else ``""``
    (NaN compares false everywhere, so it also yields ``""``).
    """
    for cutoff, marks in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marks
    return ""


def fmt_coef(res, var):
    """Format ``var``'s estimate in ``res`` as ``"coef<stars> (se)"`` to three decimals.

    Returns an empty string when ``res`` is None or has no coefficient for ``var``.
    """
    if res is not None and var in res["coefs"]:
        coef, se, pval = res["coefs"][var], res["se"][var], res["pvals"][var]
        return f"{coef:.3f}{star(pval)} ({se:.3f})"
    return ""


def build_markdown_table(title, row_vars, col_labels, results, footer_rows):
    """Render regression results as a markdown table string.

    Parameters
    ----------
    title : str
        Written as a level-1 heading above the table.
    row_vars : list of str
        One row per variable; each cell is ``fmt_coef(result, var)``.
    col_labels : list of str
        Column headers, one per entry of ``results``.
    results : list of dict or None
        Estimation results in column order.
    footer_rows : list of (str, list of str)
        Pre-formatted summary rows (R2, N, ...) appended after the coefficients.

    A second divider line separates coefficient rows from footer rows.
    """
    divider = "|" + "---|" * (len(col_labels) + 1)
    body = [
        f"# {title}\n",
        "| Variable | " + " | ".join(col_labels) + " |",
        divider,
    ]
    body.extend(
        f"| {var} | " + " | ".join(fmt_coef(res, var) for res in results) + " |"
        for var in row_vars
    )
    body.append(divider)
    body.extend(
        f"| {label} | " + " | ".join(vals) + " |"
        for label, vals in footer_rows
    )
    return "\n".join(body)


# ── main ─────────────────────────────────────────────────────────────────────
# ── phase 8 steps (called by main) ───────────────────────────────────────────
def _load_pyplot():
    """Import and return ``matplotlib.pyplot`` using the non-interactive Agg backend."""
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    return plt


def _level_term(sample, dep_var, rhs_vars, coefs):
    """Return the mean in-sample residual of ``dep_var`` after removing ``X @ coefs``.

    The coefficient dicts from ``run_model``/``run_ols`` contain slopes only
    (``run_ols`` drops its constant and ``run_model`` exposes no level terms),
    so ``X @ betas`` by itself is not a prediction of the *level* of
    ``dep_var``. Adding this mean residual restores the level: for pooled OLS
    with a constant it equals the intercept exactly. If PanelGLS absorbs
    country effects (not visible here — confirm), it is their sample-weighted
    average. Uses the same complete-row filter as the estimators; returns 0.0
    if no complete rows exist.
    """
    betas = np.array([coefs[v] for v in rhs_vars])
    est = sample[[dep_var] + rhs_vars + ["iso3", "year"]].dropna()
    if est.empty:
        return 0.0
    resid = est[dep_var].to_numpy(dtype=float) - est[rhs_vars].to_numpy(dtype=float) @ betas
    return float(resid.mean())


def _table16_post_qe(oecd_post, z_vars, controls_rate, controls_infl):
    """Estimate the 2022-2024 OECD regressions and write Table 16.

    Returns ``(specs, results, methods)``: the ``(dep_var, rhs_vars)`` pairs,
    the matching estimation dicts (``None`` where both estimators failed) and
    the estimator label used for each column.
    """
    print("\n── Table 16: 2022-2024 Post-QE Regressions (OECD) ──")
    print(f"OECD 2022-2024: {len(oecd_post):,} obs, {oecd_post['iso3'].nunique()} countries")

    specs = [
        ("real_bond_10y", z_vars + controls_rate),
        ("real_short_3m", z_vars + controls_rate),
        ("inflation", z_vars + controls_infl),
    ]

    results, methods = [], []
    for dv, rhs in specs:
        # Only three years of data: try PanelGLS with a lowered observation
        # floor, then fall back to pooled OLS.
        res = run_model(oecd_post, dv, rhs, min_obs=20)
        method = "PanelGLS"
        if res is None:
            print(f"  Falling back to OLS for {dv}")
            res = run_ols(oecd_post, dv, rhs)
            method = "OLS"
        results.append(res)
        methods.append(method)
        if res:
            z1_str = fmt_coef(res, "Z_1")
            print(f"  {dv} ({method}): Z_1={z1_str}, R2={res['r2']:.3f}, N={res['nobs']}")

    # Union of regressors across specifications, first-seen order preserved.
    all_rhs = list(dict.fromkeys(z_vars + controls_rate + controls_infl))
    footer = [
        ("R2", [f"{r['r2']:.3f}" if r else "" for r in results]),
        ("N obs", [f"{r['nobs']:,}" if r else "" for r in results]),
        ("N countries", [f"{r['ncountries']}" if r else "" for r in results]),
        ("Method", methods),
    ]

    md = build_markdown_table(
        "Table 16: Post-QE Demographics and Rates (OECD 2022-2024)",
        all_rhs, [dv for dv, _ in specs], results, footer,
    )
    md += ("\n\n*Note: Small sample (3 years). Results should be interpreted with caution. "
           "PanelGLS used where feasible; OLS fallback where insufficient variation.*")
    print("\n" + md)
    path = TABLE_DIR / "phase8_table16_post_qe.md"
    path.write_text(md)
    print(f"\nSaved: {path}")
    return specs, results, methods


def _table17_oos(oecd, oecd_pre_gfc, oecd_post, dv, rhs, z_vars):
    """Estimate on pre-GFC OECD data, predict 2022-2024 and write Table 17.

    Returns ``(res_pre, level)``: the pre-GFC estimation dict (``None`` if both
    estimators failed) and the level term added to ``X @ betas`` predictions.
    """
    print("\n── Table 17: Out-of-Sample Prediction ──")

    res_pre = run_model(oecd_pre_gfc, dv, rhs)
    if res_pre is None:
        res_pre = run_ols(oecd_pre_gfc, dv, rhs)
        print("  Pre-GFC estimation: OLS fallback")
    if res_pre is None:
        print("  WARNING: Pre-GFC estimation failed; skipping OOS test")
        return None, 0.0

    print(f"  Pre-GFC model: R2={res_pre['r2']:.3f}, N={res_pre['nobs']}")
    for v in z_vars:
        print(f"    {v}: {fmt_coef(res_pre, v)}")

    level = _level_term(oecd_pre_gfc, dv, rhs, res_pre["coefs"])
    print(f"  Level term (mean pre-GFC residual): {level:.3f}")

    oos_df = oecd_post[[dv] + rhs + ["iso3", "year"]].dropna()
    if oos_df.empty:
        print("  WARNING: No 2022-2024 OECD obs with complete data for OOS test")
        return res_pre, level

    betas = np.array([res_pre["coefs"][v] for v in rhs])
    y_oos = oos_df[dv].to_numpy(dtype=float)
    y_pred = level + oos_df[rhs].to_numpy(dtype=float) @ betas

    err = y_oos - y_pred
    rmse = np.sqrt(np.mean(err ** 2))
    mae = np.mean(np.abs(err))
    # Correlation is unaffected by the level term; undefined for a single point.
    corr = np.corrcoef(y_oos, y_pred)[0, 1] if len(y_oos) > 1 else np.nan

    # Naive benchmark: 2007 OECD average (Series.mean skips NaN),
    # falling back to the whole pre-GFC average if 2007 has no data.
    avg_2007 = oecd.loc[oecd["year"] == 2007, dv].mean()
    if np.isnan(avg_2007):
        avg_2007 = oecd_pre_gfc[dv].mean()
    naive_err = y_oos - avg_2007
    rmse_naive = np.sqrt(np.mean(naive_err ** 2))
    mae_naive = np.mean(np.abs(naive_err))

    print("\n  Out-of-sample results:")
    print(f"    Model:  RMSE={rmse:.3f}, MAE={mae:.3f}, corr={corr:.3f}")
    print(f"    Naive:  RMSE={rmse_naive:.3f}, MAE={mae_naive:.3f}")

    md_lines = [
        "# Table 17: Out-of-Sample Prediction (Pre-GFC Model on 2022-2024)\n",
        "| Metric | Demographic Model | Naive (2007 avg) |",
        "|--------|------------------:|------------------:|",
        f"| RMSE | {rmse:.3f} | {rmse_naive:.3f} |",
        f"| MAE | {mae:.3f} | {mae_naive:.3f} |",
        f"| Correlation | {corr:.3f} | -- |",
        f"| N obs | {len(y_oos)} | {len(y_oos)} |",
        "",
        f"*Pre-GFC estimation: {res_pre['nobs']} obs, "
        f"R2={res_pre['r2']:.3f}. "
        f"Model predictions include a level term equal to the mean pre-GFC "
        f"in-sample residual ({level:.3f}). "
        f"Naive benchmark: 2007 OECD average = {avg_2007:.2f}.*",
    ]
    md = "\n".join(md_lines)
    print("\n" + md)
    path = TABLE_DIR / "phase8_table17_oos.md"
    path.write_text(md)
    print(f"\nSaved: {path}")
    return res_pre, level


def _table18_rolling(oecd, dv, rhs):
    """Re-estimate ``dv`` on 10-year rolling OECD windows; write Table 18 + CSV.

    Returns the per-window results as a DataFrame (empty if no window produced
    a result).
    """
    print("\n── Table 18: Rolling Re-emergence ──")

    records = []
    for start_year in range(2005, 2016):  # 2005-2014 through 2015-2024
        end_year = start_year + 9  # 10-year windows, inclusive
        window = oecd[oecd["year"].between(start_year, end_year)]

        res = run_model(window, dv, rhs, min_obs=30)
        method = "PanelGLS"
        if res is None:
            res = run_ols(window, dv, rhs)
            method = "OLS"

        if res:
            records.append({
                "window": f"{start_year}-{end_year}",
                "start_year": start_year,
                "end_year": end_year,
                "Z_1_coef": res["coefs"]["Z_1"],
                "Z_1_se": res["se"]["Z_1"],
                "Z_1_pval": res["pvals"]["Z_1"],
                "r2": res["r2"],
                "nobs": res["nobs"],
                "ncountries": res["ncountries"],
                "method": method,
            })

    rolling_df = pd.DataFrame(records)
    if rolling_df.empty:
        print("  WARNING: No rolling windows produced results")
        return rolling_df

    csv_path = TABLE_DIR / "phase8_reemergence_data.csv"
    rolling_df.to_csv(csv_path, index=False)
    print(f"Saved: {csv_path}")

    md_lines = [
        "# Table 18: Rolling 10-Year Window Z_1 on Real 10-Year Bond Yield (OECD)\n",
        "| Window | Z_1 coef | Z_1 SE | p-value | R2 | N | Method |",
        "|--------|--------:|-------:|--------:|---:|--:|--------|",
    ]
    for _, row in rolling_df.iterrows():
        md_lines.append(
            f"| {row['window']} "
            f"| {row['Z_1_coef']:.3f}{star(row['Z_1_pval'])} "
            f"| {row['Z_1_se']:.3f} "
            f"| {row['Z_1_pval']:.4f} "
            f"| {row['r2']:.3f} "
            f"| {int(row['nobs']):,} "
            f"| {row['method']} |"
        )
    md_lines.append("")
    md_lines.append("*Windows are 10-year rolling, OECD only. "
                    "PanelGLS where feasible, OLS fallback.*")

    md = "\n".join(md_lines)
    print("\n" + md)
    path = TABLE_DIR / "phase8_table18_reemergence.md"
    path.write_text(md)
    print(f"\nSaved: {path}")
    return rolling_df


def _plot_predicted_vs_actual(country_df):
    """Scatter country-level predicted vs actual yields with a 45-degree line."""
    print("\n── Figure: Predicted vs Actual Scatter ──")
    plt = _load_pyplot()

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(country_df["Predicted"], country_df["Actual"],
               s=60, alpha=0.7, edgecolors="black", linewidths=0.5)

    all_vals = np.concatenate([country_df["Predicted"].to_numpy(),
                               country_df["Actual"].to_numpy()])
    lo, hi = np.nanmin(all_vals) - 1, np.nanmax(all_vals) + 1
    ax.plot([lo, hi], [lo, hi], "r--", linewidth=1.5, label="45-degree line")

    for _, row in country_df.iterrows():
        ax.annotate(row["Country"], (row["Predicted"], row["Actual"]),
                    fontsize=7, ha="center", va="bottom")

    ax.set_xlabel("Predicted Real 10y Yield (Pre-GFC model)")
    ax.set_ylabel("Actual Real 10y Yield (2022-2024 avg)")
    ax.set_title("Out-of-Sample: Predicted vs Actual Real 10-Year Bond Yield")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_aspect("equal", adjustable="datalim")
    plt.tight_layout()

    fig_path = FIG_DIR / "phase8_predicted_vs_actual.png"
    plt.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved: {fig_path}")


def _table19_country(oecd_post, dv, rhs, res_pre, level):
    """Write Table 19 (country 2022-2024 averages, predicted vs actual) and its scatter."""
    print("\n── Table 19: Country-Level Predicted vs Actual ──")
    if res_pre is None:
        return

    oos_df = oecd_post[["iso3", dv] + rhs].dropna()
    if oos_df.empty:
        print("  WARNING: No country-level data for 2022-2024")
        return

    betas = np.array([res_pre["coefs"][v] for v in rhs])
    country_avg = oos_df.groupby("iso3").mean()
    # The model is linear, so predicting from country-average regressors
    # equals averaging the annual predictions.
    predicted = level + country_avg[rhs].to_numpy(dtype=float) @ betas
    country_df = pd.DataFrame({
        "Country": country_avg.index.to_numpy(),
        "Predicted": predicted,
        "Actual": country_avg[dv].to_numpy(dtype=float),
    })
    country_df["Residual"] = country_df["Actual"] - country_df["Predicted"]
    country_df = country_df.sort_values("Residual")

    md_lines = [
        "# Table 19: Country-Level Predicted vs Actual Real 10y Yield "
        "(2022-2024 avg)\n",
        "| Country | Predicted | Actual | Residual |",
        "|---------|----------:|-------:|---------:|",
    ]
    for _, row in country_df.iterrows():
        md_lines.append(
            f"| {row['Country']} "
            f"| {row['Predicted']:.2f} "
            f"| {row['Actual']:.2f} "
            f"| {row['Residual']:.2f} |"
        )
    md_lines.append("")
    md_lines.append(f"*Predicted using pre-GFC (<=2007) OECD estimated coefficients "
                    f"plus the mean pre-GFC residual ({level:.3f}) as level term. "
                    f"N = {len(country_df)} countries.*")

    md = "\n".join(md_lines)
    print("\n" + md)
    path = TABLE_DIR / "phase8_table19_country.md"
    path.write_text(md)
    print(f"\nSaved: {path}")

    _plot_predicted_vs_actual(country_df)


def _natural_rate_projection(oecd, res_pre):
    """Plot the OECD-average ``Z_1 * coef`` demographic rate contribution over time."""
    print("\n── Forward Projection: Demographic Natural Rate ──")
    if res_pre is None:
        return

    z1_coef = res_pre["coefs"]["Z_1"]
    print(f"  Pre-GFC Z_1 coefficient: {z1_coef:.3f}")

    # Prefer the multilateral full panel (may extend to 2050); fall back to
    # the monetary panel if it cannot be read or lacks the needed columns.
    try:
        full_panel = pd.read_csv(FULL_PANEL_PATH)
        fp_oecd = full_panel[full_panel["iso3"].isin(OECD_38)]
        max_year = fp_oecd["year"].max()
        print(f"  Full panel max year: {max_year}")

        if max_year >= 2050:
            proj_data = fp_oecd[fp_oecd["year"] <= 2050][["iso3", "year", "Z_1"]].dropna()
        else:
            proj_data = fp_oecd[["iso3", "year", "Z_1"]].dropna()
            print(f"  WARNING: Projections only available to {max_year}")
    except Exception as e:
        print(f"  Could not load full_panel for projections: {e}")
        proj_data = oecd[["iso3", "year", "Z_1"]].dropna()

    if proj_data.empty:
        return

    proj_data = proj_data.copy()
    proj_data["demo_rate_contribution"] = z1_coef * proj_data["Z_1"]

    annual_avg = (
        proj_data.groupby("year")["demo_rate_contribution"]
        .agg(["mean", "std", "count"])
        .reset_index()
    )
    annual_avg.columns = ["year", "mean_contribution", "std_contribution", "n_countries"]
    # Keep only years with reasonable cross-country coverage.
    annual_avg = annual_avg[annual_avg["n_countries"] >= 5]

    plt = _load_pyplot()
    fig, ax = plt.subplots(figsize=(12, 6))

    years = annual_avg["year"].to_numpy()
    mean_c = annual_avg["mean_contribution"].to_numpy()
    std_c = annual_avg["std_contribution"].to_numpy()

    ax.plot(years, mean_c, "b-", linewidth=2, label="OECD avg demographic rate contribution")
    ax.fill_between(years, mean_c - std_c, mean_c + std_c,
                    alpha=0.2, color="blue", label="+/- 1 std")

    # Mark where historical data ends, if any projected years are present.
    if np.any(years > 2024):
        ax.axvline(2024, color="gray", linestyle="--", linewidth=1,
                   label="Projection starts")

    ax.axhline(0, color="black", linewidth=0.5)
    ax.set_xlabel("Year")
    ax.set_ylabel("Demographic Rate Contribution (pp)")
    ax.set_title("Demographic Natural Rate Contribution (OECD Average)\n"
                 f"Based on pre-GFC Z_1 coefficient = {z1_coef:.2f}")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    fig_path = FIG_DIR / "phase8_natural_rate_projection.png"
    plt.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved: {fig_path}")


def _print_summary(specs, results, methods, rolling_df):
    """Print the Z_1 headline results from Table 16 and the rolling windows."""
    print("\n── Key Findings ──")

    for (dv_name, _), r, method in zip(specs, results, methods):
        if r:
            z1_c = r["coefs"].get("Z_1", np.nan)
            z1_p = r["pvals"].get("Z_1", np.nan)
            print(f"  Post-QE {dv_name}: Z_1={z1_c:.3f} (p={z1_p:.3f}) [{method}]")

    if rolling_df.empty:
        return
    first_w = rolling_df.iloc[0]
    last_w = rolling_df.iloc[-1]
    print(f"  Rolling Z_1: {first_w['window']}={first_w['Z_1_coef']:.3f} "
          f"(p={first_w['Z_1_pval']:.3f}) -> "
          f"{last_w['window']}={last_w['Z_1_coef']:.3f} "
          f"(p={last_w['Z_1_pval']:.3f})")
    if last_w["Z_1_pval"] < 0.10 and first_w["Z_1_pval"] > 0.10:
        print("  -> EVIDENCE of re-emergence: Z_1 became significant in recent windows")
    elif last_w["Z_1_pval"] < first_w["Z_1_pval"]:
        print("  -> Z_1 p-value declining (trending toward significance)")
    else:
        print("  -> No clear re-emergence pattern")


def main():
    """Run Phase 8: Tables 16-19, the predicted-vs-actual scatter and the projection.

    Reads ``DATA_PATH`` (and ``FULL_PANEL_PATH`` for the projection), restricts
    to ``OECD_38`` and writes markdown/CSV tables to ``TABLE_DIR`` and PNG
    figures to ``FIG_DIR``. Out-of-sample predictions use pre-GFC (<=2007)
    slope coefficients plus a level term (mean pre-GFC residual), so they are
    comparable in level to actual yields and the naive benchmark.
    """
    print("=" * 70)
    print("PHASE 8: POST-QE PREDICTION AND RE-EMERGENCE")
    print("=" * 70)

    panel = pd.read_csv(DATA_PATH)
    print(f"Panel: {len(panel):,} obs, {panel['iso3'].nunique()} countries")

    oecd = panel[panel["iso3"].isin(OECD_38)].copy()

    z_vars = ["Z_1", "Z_2", "Z_3"]
    controls_rate = ["rgdp_growth", "inflation", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]
    controls_infl = ["rgdp_growth", "output_gap", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]

    oecd_post = oecd[oecd["year"].between(2022, 2024)].copy()
    oecd_pre_gfc = oecd[oecd["year"] <= 2007].copy()

    # Specification shared by the out-of-sample, rolling and country analyses.
    dv = "real_bond_10y"
    rhs = z_vars + controls_rate

    specs, results_t16, methods_t16 = _table16_post_qe(
        oecd_post, z_vars, controls_rate, controls_infl)
    res_pre, level = _table17_oos(oecd, oecd_pre_gfc, oecd_post, dv, rhs, z_vars)
    rolling_df = _table18_rolling(oecd, dv, rhs)
    _table19_country(oecd_post, dv, rhs, res_pre, level)
    _natural_rate_projection(oecd, res_pre)
    _print_summary(specs, results_t16, methods_t16, rolling_df)

    print("\nPhase 8 complete.")


# Run the Phase 8 pipeline only when executed as a script.
if __name__ == "__main__":
    main()
